{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp models.fedformer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FEDformer"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The FEDformer model tackles the challenge of finding reliable dependencies on intricate temporal patterns of long-horizon forecasting.\n",
    "\n",
    "The architecture has the following distinctive features:\n",
    "- In-built progressive decomposition in trend and seasonal components based on a moving average filter.\n",
    "- Frequency Enhanced Block and Frequency Enhanced Attention to perform attention in the sparse representation on basis such as Fourier transform.\n",
    "- Classic encoder-decoder proposed by Vaswani et al. (2017) with a multi-head attention mechanism.\n",
    "\n",
    "The FEDformer model utilizes a three-component approach to define its embedding:\n",
    "- It employs encoded autoregressive features obtained from a convolution network.\n",
    "- Absolute positional embeddings obtained from calendar features are utilized."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**References**<br>\n",
    "- [Zhou, Tian, Ziqing Ma, Qingsong Wen, Xue Wang, Liang Sun, and Rong Jin.. \"FEDformer: Frequency enhanced decomposed transformer for long-term series forecasting\"](https://proceedings.mlr.press/v162/zhou22g.html)<br>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Figure 1. FEDformer Architecture.](imgs_models/fedformer.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import numpy as np\n",
    "from typing import Optional\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from neuralforecast.common._modules import DataEmbedding\n",
    "from neuralforecast.common._base_windows import BaseWindows\n",
    "\n",
    "from neuralforecast.losses.pytorch import MAE"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Auxiliary functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MovingAvg(nn.Module):\n",
    "    \"\"\"\n",
    "    Moving average block to highlight the trend of time series\n",
    "    \"\"\"\n",
    "    def __init__(self, kernel_size, stride):\n",
    "        super(MovingAvg, self).__init__()\n",
    "        self.kernel_size = kernel_size\n",
    "        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # padding on the both ends of time series\n",
    "        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n",
    "        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)\n",
    "        x = torch.cat([front, x, end], dim=1)\n",
    "        x = self.avg(x.permute(0, 2, 1))\n",
    "        x = x.permute(0, 2, 1)\n",
    "        return x\n",
    "\n",
    "class SeriesDecomp(nn.Module):\n",
    "    \"\"\"\n",
    "    Series decomposition block\n",
    "    \"\"\"\n",
    "    def __init__(self, kernel_size):\n",
    "        super(SeriesDecomp, self).__init__()\n",
    "        self.MovingAvg = MovingAvg(kernel_size, stride=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        moving_mean = self.MovingAvg(x)\n",
    "        res = x - moving_mean\n",
    "        return res, moving_mean\n",
    "    \n",
    "class LayerNorm(nn.Module):\n",
    "    \"\"\"\n",
    "    Special designed layernorm for the seasonal part\n",
    "    \"\"\"\n",
    "    def __init__(self, channels):\n",
    "        super(LayerNorm, self).__init__()\n",
    "        self.layernorm = nn.LayerNorm(channels)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x_hat = self.layernorm(x)\n",
    "        bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)\n",
    "        return x_hat - bias\n",
    "\n",
    "\n",
    "class AutoCorrelationLayer(nn.Module):\n",
    "    def __init__(self, correlation, hidden_size, n_head, d_keys=None,\n",
    "                 d_values=None):\n",
    "        super(AutoCorrelationLayer, self).__init__()\n",
    "\n",
    "        d_keys = d_keys or (hidden_size // n_head)\n",
    "        d_values = d_values or (hidden_size // n_head)\n",
    "\n",
    "        self.inner_correlation = correlation\n",
    "        self.query_projection = nn.Linear(hidden_size, d_keys * n_head)\n",
    "        self.key_projection = nn.Linear(hidden_size, d_keys * n_head)\n",
    "        self.value_projection = nn.Linear(hidden_size, d_values * n_head)\n",
    "        self.out_projection = nn.Linear(d_values * n_head, hidden_size)\n",
    "        self.n_head = n_head\n",
    "\n",
    "    def forward(self, queries, keys, values, attn_mask):\n",
    "        B, L, _ = queries.shape\n",
    "        _, S, _ = keys.shape\n",
    "        H = self.n_head\n",
    "\n",
    "        queries = self.query_projection(queries).view(B, L, H, -1)\n",
    "        keys = self.key_projection(keys).view(B, S, H, -1)\n",
    "        values = self.value_projection(values).view(B, S, H, -1)\n",
    "\n",
    "        out, attn = self.inner_correlation(\n",
    "            queries,\n",
    "            keys,\n",
    "            values,\n",
    "            attn_mask\n",
    "        )\n",
    "        out = out.view(B, L, -1)\n",
    "\n",
    "        return self.out_projection(out), attn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class EncoderLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    FEDformer encoder layer with the progressive decomposition architecture\n",
    "    \"\"\"\n",
    "    def __init__(self, attention, hidden_size, conv_hidden_size=None, MovingAvg=25, dropout=0.1, activation=\"relu\"):\n",
    "        super(EncoderLayer, self).__init__()\n",
    "        conv_hidden_size = conv_hidden_size or 4 * hidden_size\n",
    "        self.attention = attention\n",
    "        self.conv1 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_hidden_size, kernel_size=1, bias=False)\n",
    "        self.conv2 = nn.Conv1d(in_channels=conv_hidden_size, out_channels=hidden_size, kernel_size=1, bias=False)\n",
    "        self.decomp1 = SeriesDecomp(MovingAvg)\n",
    "        self.decomp2 = SeriesDecomp(MovingAvg)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.activation = F.relu if activation == \"relu\" else F.gelu\n",
    "\n",
    "    def forward(self, x, attn_mask=None):\n",
    "        new_x, attn = self.attention(\n",
    "            x, x, x,\n",
    "            attn_mask=attn_mask\n",
    "        )\n",
    "        x = x + self.dropout(new_x)\n",
    "        x, _ = self.decomp1(x)\n",
    "        y = x\n",
    "        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n",
    "        y = self.dropout(self.conv2(y).transpose(-1, 1))\n",
    "        res, _ = self.decomp2(x + y)\n",
    "        return res, attn\n",
    "\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    \"\"\"\n",
    "    FEDformer encoder\n",
    "    \"\"\"\n",
    "    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.attn_layers = nn.ModuleList(attn_layers)\n",
    "        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None\n",
    "        self.norm = norm_layer\n",
    "\n",
    "    def forward(self, x, attn_mask=None):\n",
    "        attns = []\n",
    "        if self.conv_layers is not None:\n",
    "            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):\n",
    "                x, attn = attn_layer(x, attn_mask=attn_mask)\n",
    "                x = conv_layer(x)\n",
    "                attns.append(attn)\n",
    "            x, attn = self.attn_layers[-1](x)\n",
    "            attns.append(attn)\n",
    "        else:\n",
    "            for attn_layer in self.attn_layers:\n",
    "                x, attn = attn_layer(x, attn_mask=attn_mask)\n",
    "                attns.append(attn)\n",
    "\n",
    "        if self.norm is not None:\n",
    "            x = self.norm(x)\n",
    "\n",
    "        return x, attns\n",
    "\n",
    "\n",
    "class DecoderLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    FEDformer decoder layer with the progressive decomposition architecture\n",
    "    \"\"\"\n",
    "    def __init__(self, self_attention, cross_attention, hidden_size, c_out, conv_hidden_size=None,\n",
    "                 MovingAvg=25, dropout=0.1, activation=\"relu\"):\n",
    "        super(DecoderLayer, self).__init__()\n",
    "        conv_hidden_size = conv_hidden_size or 4 * hidden_size\n",
    "        self.self_attention = self_attention\n",
    "        self.cross_attention = cross_attention\n",
    "        self.conv1 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_hidden_size, kernel_size=1, bias=False)\n",
    "        self.conv2 = nn.Conv1d(in_channels=conv_hidden_size, out_channels=hidden_size, kernel_size=1, bias=False)\n",
    "        self.decomp1 = SeriesDecomp(MovingAvg)\n",
    "        self.decomp2 = SeriesDecomp(MovingAvg)\n",
    "        self.decomp3 = SeriesDecomp(MovingAvg)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.projection = nn.Conv1d(in_channels=hidden_size, out_channels=c_out, kernel_size=3, stride=1, padding=1,\n",
    "                                    padding_mode='circular', bias=False)\n",
    "        self.activation = F.relu if activation == \"relu\" else F.gelu\n",
    "\n",
    "    def forward(self, x, cross, x_mask=None, cross_mask=None):\n",
    "        x = x + self.dropout(self.self_attention(\n",
    "            x, x, x,\n",
    "            attn_mask=x_mask\n",
    "        )[0])\n",
    "        x, trend1 = self.decomp1(x)\n",
    "        x = x + self.dropout(self.cross_attention(\n",
    "            x, cross, cross,\n",
    "            attn_mask=cross_mask\n",
    "        )[0])\n",
    "        x, trend2 = self.decomp2(x)\n",
    "        y = x\n",
    "        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))\n",
    "        y = self.dropout(self.conv2(y).transpose(-1, 1))\n",
    "        x, trend3 = self.decomp3(x + y)\n",
    "\n",
    "        residual_trend = trend1 + trend2 + trend3\n",
    "        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)\n",
    "        return x, residual_trend\n",
    "\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"\n",
    "    FEDformer decoder\n",
    "    \"\"\"\n",
    "    def __init__(self, layers, norm_layer=None, projection=None):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.layers = nn.ModuleList(layers)\n",
    "        self.norm = norm_layer\n",
    "        self.projection = projection\n",
    "\n",
    "    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):\n",
    "        for layer in self.layers:\n",
    "            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)\n",
    "            trend = trend + residual_trend\n",
    "\n",
    "        if self.norm is not None:\n",
    "            x = self.norm(x)\n",
    "\n",
    "        if self.projection is not None:\n",
    "            x = self.projection(x)\n",
    "        return x, trend"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def get_frequency_modes(seq_len, modes=64, mode_select_method='random'):\n",
    "    \"\"\"\n",
    "    Get modes on frequency domain:\n",
    "        'random' for sampling randomly\n",
    "        'else' for sampling the lowest modes;\n",
    "    \"\"\"\n",
    "    modes = min(modes, seq_len//2)\n",
    "    if mode_select_method == 'random':\n",
    "        index = list(range(0, seq_len // 2))\n",
    "        np.random.shuffle(index)\n",
    "        index = index[:modes]\n",
    "    else:\n",
    "        index = list(range(0, modes))\n",
    "    index.sort()\n",
    "    return index\n",
    "\n",
    "\n",
    "class FourierBlock(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'):\n",
    "        super(FourierBlock, self).__init__()\n",
    "        \"\"\"\n",
    "        Fourier block\n",
    "        \"\"\"\n",
    "        # get modes on frequency domain\n",
    "        self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)\n",
    "\n",
    "        self.scale = (1 / (in_channels * out_channels))\n",
    "        self.weights1 = nn.Parameter(\n",
    "            self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.cfloat))\n",
    "\n",
    "    # Complex multiplication\n",
    "    def compl_mul1d(self, input, weights):\n",
    "        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)\n",
    "        return torch.einsum(\"bhi,hio->bho\", input, weights)\n",
    "\n",
    "    def forward(self, q, k, v, mask):\n",
    "        # size = [B, L, H, E]\n",
    "        B, L, H, E = q.shape\n",
    "        \n",
    "        x = q.permute(0, 2, 3, 1)\n",
    "        # Compute Fourier coefficients\n",
    "        x_ft = torch.fft.rfft(x, dim=-1)\n",
    "        # Perform Fourier neural operations\n",
    "        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)\n",
    "        for wi, i in enumerate(self.index):\n",
    "            out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi])\n",
    "        # Return to time domain\n",
    "        x = torch.fft.irfft(out_ft, n=x.size(-1))\n",
    "        return (x, None)\n",
    "\n",
    "class FourierCrossAttention(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',\n",
    "                 activation='tanh', policy=0):\n",
    "        super(FourierCrossAttention, self).__init__()\n",
    "        \"\"\"\n",
    "        Fourier Cross Attention layer\n",
    "        \"\"\"\n",
    "        self.activation = activation\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        # get modes for queries and keys (& values) on frequency domain\n",
    "        self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)\n",
    "        self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)\n",
    "\n",
    "        self.scale = (1 / (in_channels * out_channels))\n",
    "        self.weights1 = nn.Parameter(\n",
    "            self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index_q), dtype=torch.cfloat))\n",
    "\n",
    "    # Complex multiplication\n",
    "    def compl_mul1d(self, input, weights):\n",
    "        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)\n",
    "        return torch.einsum(\"bhi,hio->bho\", input, weights)\n",
    "\n",
    "    def forward(self, q, k, v, mask):\n",
    "        # size = [B, L, H, E]\n",
    "        B, L, H, E = q.shape\n",
    "        xq = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]\n",
    "        xk = k.permute(0, 2, 3, 1)\n",
    "        #xv = v.permute(0, 2, 3, 1)\n",
    "\n",
    "        # Compute Fourier coefficients\n",
    "        xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)\n",
    "        xq_ft = torch.fft.rfft(xq, dim=-1)\n",
    "        for i, j in enumerate(self.index_q):\n",
    "            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]\n",
    "        xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)\n",
    "        xk_ft = torch.fft.rfft(xk, dim=-1)\n",
    "        for i, j in enumerate(self.index_kv):\n",
    "            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]\n",
    "\n",
    "        # Attention mechanism on frequency domain\n",
    "        xqk_ft = (torch.einsum(\"bhex,bhey->bhxy\", xq_ft_, xk_ft_))\n",
    "        if self.activation == 'tanh':\n",
    "            xqk_ft = xqk_ft.tanh()\n",
    "        elif self.activation == 'softmax':\n",
    "            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)\n",
    "            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))\n",
    "        else:\n",
    "            raise Exception('{} actiation function is not implemented'.format(self.activation))\n",
    "        xqkv_ft = torch.einsum(\"bhxy,bhey->bhex\", xqk_ft, xk_ft_)\n",
    "        xqkvw = torch.einsum(\"bhex,heox->bhox\", xqkv_ft, self.weights1)\n",
    "        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)\n",
    "        for i, j in enumerate(self.index_q):\n",
    "            out_ft[:, :, :, j] = xqkvw[:, :, :, i]\n",
    "        \n",
    "        # Return to time domain\n",
    "        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))\n",
    "        return (out, None)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class FEDformer(BaseWindows):\n",
    "    \"\"\" FEDformer\n",
    "\n",
    "    The FEDformer model tackles the challenge of finding reliable dependencies on intricate temporal patterns of long-horizon forecasting.\n",
    "\n",
    "    The architecture has the following distinctive features:\n",
    "    - In-built progressive decomposition in trend and seasonal components based on a moving average filter.\n",
    "    - Frequency Enhanced Block and Frequency Enhanced Attention to perform attention in the sparse representation on basis such as Fourier transform.\n",
    "    - Classic encoder-decoder proposed by Vaswani et al. (2017) with a multi-head attention mechanism.\n",
    "\n",
    "    The FEDformer model utilizes a three-component approach to define its embedding:\n",
    "    - It employs encoded autoregressive features obtained from a convolution network.\n",
    "    - Absolute positional embeddings obtained from calendar features are utilized.\n",
    "\n",
    "    *Parameters:*<br>\n",
    "    `h`: int, forecast horizon.<br>\n",
    "    `input_size`: int, maximum sequence length for truncated train backpropagation. Default -1 uses all history.<br>\n",
    "    `futr_exog_list`: str list, future exogenous columns.<br>\n",
    "    `hist_exog_list`: str list, historic exogenous columns.<br>\n",
    "    `stat_exog_list`: str list, static exogenous columns.<br>\n",
    "\t`decoder_input_size_multiplier`: float = 0.5, .<br>\n",
    "    `version`: str = 'Fourier', version of the model.<br>\n",
    "    `modes`: int = 64, number of modes for the Fourier block.<br>\n",
    "    `mode_select`: str = 'random', method to select the modes for the Fourier block.<br>\n",
    "    `hidden_size`: int=128, units of embeddings and encoders.<br>\n",
    "    `dropout`: float (0, 1), dropout throughout Autoformer architecture.<br>\n",
    "    `n_head`: int=8, controls number of multi-head's attention.<br>\n",
    "\t`conv_hidden_size`: int=32, channels of the convolutional encoder.<br>\n",
    "\t`activation`: str=`GELU`, activation from ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'PReLU', 'Sigmoid', 'GELU'].<br>\n",
    "    `encoder_layers`: int=2, number of layers for the TCN encoder.<br>\n",
    "    `decoder_layers`: int=1, number of layers for the MLP decoder.<br>\n",
    "    `MovingAvg_window`: int=25, window size for the moving average filter.<br>\n",
    "    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `valid_loss`: PyTorch module, instantiated validation loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
    "    `max_steps`: int=1000, maximum number of training steps.<br>\n",
    "    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>\n",
    "    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
    "    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
    "    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
    "    `batch_size`: int=32, number of different series in each batch.<br>\n",
    "    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
    "    `windows_batch_size`: int=1024, number of windows to sample in each training batch, default uses all.<br>\n",
    "    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch.<br>\n",
    "    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
    "    `scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
    "    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
    "    `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.<br>\n",
    "    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
    "    `alias`: str, optional,  Custom name of the model.<br>\n",
    "    `**trainer_kwargs`: int,  keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
    "\n",
    "    \"\"\"\n",
    "    # Class attributes\n",
    "    SAMPLING_TYPE = 'windows'\n",
    "\n",
    "    def __init__(self,\n",
    "                 h: int, \n",
    "                 input_size: int,\n",
    "                 stat_exog_list = None,\n",
    "                 hist_exog_list = None,\n",
    "                 futr_exog_list = None,\n",
    "                 decoder_input_size_multiplier: float = 0.5,\n",
    "                 version: str = 'Fourier',\n",
    "                 modes: int = 64,\n",
    "                 mode_select: str = 'random',\n",
    "                 hidden_size: int = 128, \n",
    "                 dropout: float = 0.05,\n",
    "                 n_head: int = 8,\n",
    "                 conv_hidden_size: int = 32,\n",
    "                 activation: str = 'gelu',\n",
    "                 encoder_layers: int = 2, \n",
    "                 decoder_layers: int = 1,\n",
    "                 MovingAvg_window: int = 25,\n",
    "                 loss = MAE(),\n",
    "                 valid_loss = None,\n",
    "                 max_steps: int = 5000,\n",
    "                 learning_rate: float = 1e-4,\n",
    "                 num_lr_decays: int = -1,\n",
    "                 early_stop_patience_steps: int =-1,\n",
    "                 start_padding_enabled = False,\n",
    "                 val_check_steps: int = 100,\n",
    "                 batch_size: int = 32,\n",
    "                 valid_batch_size: Optional[int] = None,\n",
    "                 windows_batch_size = 1024,\n",
    "                 inference_windows_batch_size = 1024,\n",
    "                 step_size: int = 1,\n",
    "                 scaler_type: str = 'identity',\n",
    "                 random_seed: int = 1,\n",
    "                 num_workers_loader: int = 0,\n",
    "                 drop_last_loader: bool = False,\n",
    "                 **trainer_kwargs):\n",
    "        super(FEDformer, self).__init__(h=h,\n",
    "                                       input_size=input_size,\n",
    "                                       hist_exog_list=hist_exog_list,\n",
    "                                       stat_exog_list=stat_exog_list,\n",
    "                                       futr_exog_list = futr_exog_list,\n",
    "                                       loss=loss,\n",
    "                                       valid_loss=valid_loss,\n",
    "                                       max_steps=max_steps,\n",
    "                                       learning_rate=learning_rate,\n",
    "                                       num_lr_decays=num_lr_decays,\n",
    "                                       early_stop_patience_steps=early_stop_patience_steps,\n",
    "                                       val_check_steps=val_check_steps,\n",
    "                                       batch_size=batch_size,\n",
    "                                       windows_batch_size=windows_batch_size,\n",
    "                                       valid_batch_size=valid_batch_size,\n",
    "                                       inference_windows_batch_size=inference_windows_batch_size,\n",
    "                                       start_padding_enabled=start_padding_enabled,\n",
    "                                       step_size=step_size,\n",
    "                                       scaler_type=scaler_type,\n",
    "                                       num_workers_loader=num_workers_loader,\n",
    "                                       drop_last_loader=drop_last_loader,\n",
    "                                       random_seed=random_seed,\n",
    "                                       **trainer_kwargs)\n",
    "        # Architecture\n",
    "        self.futr_input_size = len(self.futr_exog_list)\n",
    "        self.hist_input_size = len(self.hist_exog_list)\n",
    "        self.stat_input_size = len(self.stat_exog_list)\n",
    "\n",
    "        if self.stat_input_size > 0:\n",
    "            raise Exception('Autoformer does not support static variables yet')\n",
    "        \n",
    "        if self.hist_input_size > 0:\n",
    "            raise Exception('Autoformer does not support historical variables yet')\n",
    "\n",
    "        self.label_len = int(np.ceil(input_size * decoder_input_size_multiplier))\n",
    "        if (self.label_len >= input_size) or (self.label_len <= 0):\n",
    "            raise Exception(f'Check decoder_input_size_multiplier={decoder_input_size_multiplier}, range (0,1)')\n",
    "\n",
    "        if activation not in ['relu', 'gelu']:\n",
    "            raise Exception(f'Check activation={activation}')\n",
    "        \n",
    "        if n_head != 8:\n",
    "            raise Exception('n_head must be 8')\n",
    "        \n",
    "        if version not in ['Fourier']:\n",
    "            raise Exception('Only Fourier version is supported currently.')\n",
    "\n",
    "        self.c_out = self.loss.outputsize_multiplier\n",
    "        self.output_attention = False\n",
    "        self.enc_in = 1 \n",
    "        self.dec_in = 1\n",
    "        \n",
    "        self.decomp = SeriesDecomp(MovingAvg_window)\n",
    "\n",
    "        # Embedding\n",
    "        self.enc_embedding = DataEmbedding(c_in=self.enc_in,\n",
    "                                           exog_input_size=self.hist_input_size,\n",
    "                                           hidden_size=hidden_size, \n",
    "                                           pos_embedding=False,\n",
    "                                           dropout=dropout)\n",
    "        self.dec_embedding = DataEmbedding(self.dec_in,\n",
    "                                           exog_input_size=self.hist_input_size,\n",
    "                                           hidden_size=hidden_size, \n",
    "                                           pos_embedding=False,\n",
    "                                           dropout=dropout)\n",
    "\n",
    "        encoder_self_att = FourierBlock(in_channels=hidden_size,\n",
    "                                        out_channels=hidden_size,\n",
    "                                        seq_len=input_size,\n",
    "                                        modes=modes,\n",
    "                                        mode_select_method=mode_select)\n",
    "        decoder_self_att = FourierBlock(in_channels=hidden_size,\n",
    "                                        out_channels=hidden_size,\n",
    "                                        seq_len=input_size//2+self.h,\n",
    "                                        modes=modes,\n",
    "                                        mode_select_method=mode_select)\n",
    "        decoder_cross_att = FourierCrossAttention(in_channels=hidden_size,\n",
    "                                                    out_channels=hidden_size,\n",
    "                                                    seq_len_q=input_size//2+self.h,\n",
    "                                                    seq_len_kv=input_size,\n",
    "                                                    modes=modes,\n",
    "                                                    mode_select_method=mode_select)\n",
    "\n",
    "        self.encoder = Encoder(\n",
    "            [\n",
    "                EncoderLayer(\n",
    "                    AutoCorrelationLayer(\n",
    "                        encoder_self_att,\n",
    "                        hidden_size, n_head),\n",
    "\n",
    "                    hidden_size=hidden_size,\n",
    "                    conv_hidden_size=conv_hidden_size,\n",
    "                    MovingAvg=MovingAvg_window,\n",
    "                    dropout=dropout,\n",
    "                    activation=activation\n",
    "                ) for l in range(encoder_layers)\n",
    "            ],\n",
    "            norm_layer=LayerNorm(hidden_size)\n",
    "        )\n",
    "        # Decoder\n",
    "        self.decoder = Decoder(\n",
    "            [\n",
    "                DecoderLayer(\n",
    "                    AutoCorrelationLayer(\n",
    "                        decoder_self_att,\n",
    "                        hidden_size, n_head),\n",
    "                    AutoCorrelationLayer(\n",
    "                        decoder_cross_att,\n",
    "                        hidden_size, n_head),\n",
    "                    hidden_size=hidden_size,\n",
    "                    c_out=self.c_out,\n",
    "                    conv_hidden_size=conv_hidden_size,\n",
    "                    MovingAvg=MovingAvg_window,\n",
    "                    dropout=dropout,\n",
    "                    activation=activation,\n",
    "                )\n",
    "                for l in range(decoder_layers)\n",
    "            ],\n",
    "            norm_layer=LayerNorm(hidden_size),\n",
    "            projection=nn.Linear(hidden_size, self.c_out, bias=True)\n",
    "        )\n",
    "\n",
    "    def forward(self, windows_batch):\n",
    "        # Parse windows_batch\n",
    "        insample_y    = windows_batch['insample_y']\n",
    "        #insample_mask = windows_batch['insample_mask']\n",
    "        #hist_exog     = windows_batch['hist_exog']\n",
    "        #stat_exog     = windows_batch['stat_exog']\n",
    "        futr_exog     = windows_batch['futr_exog']\n",
    "\n",
    "        # Parse inputs\n",
    "        insample_y = insample_y.unsqueeze(-1) # [Ws,L,1]\n",
    "        if self.futr_input_size > 0:\n",
    "            x_mark_enc = futr_exog[:,:self.input_size,:]\n",
    "            x_mark_dec = futr_exog[:,-(self.label_len+self.h):,:]\n",
    "        else:\n",
    "            x_mark_enc = None\n",
    "            x_mark_dec = None\n",
    "\n",
    "        x_dec = torch.zeros(size=(len(insample_y),self.h, self.dec_in)).to(insample_y.device)\n",
    "        x_dec = torch.cat([insample_y[:,-self.label_len:,:], x_dec], dim=1)\n",
    "                \n",
    "        # decomp init\n",
    "        mean = torch.mean(insample_y, dim=1).unsqueeze(1).repeat(1, self.h, 1)\n",
    "        zeros = torch.zeros([x_dec.shape[0], self.h, x_dec.shape[2]], device=insample_y.device)\n",
    "        seasonal_init, trend_init = self.decomp(insample_y)\n",
    "        # decoder input\n",
    "        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)\n",
    "        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)\n",
    "        # enc\n",
    "        enc_out = self.enc_embedding(insample_y, x_mark_enc)\n",
    "        enc_out, attns = self.encoder(enc_out, attn_mask=None)\n",
    "        # dec\n",
    "        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)\n",
    "        seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None,\n",
    "                                                 trend=trend_init)\n",
    "        # final\n",
    "        dec_out = trend_part + seasonal_part\n",
    "\n",
    "        forecast = self.loss.domain_map(dec_out[:, -self.h:])\n",
    "        return forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import MLP\n",
    "# MAE is used as the training loss in the example below.\n",
    "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss, MSE, MAE\n",
    "from neuralforecast.tsdataset import TimeSeriesDataset\n",
    "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n",
    "\n",
    "# Add calendar features; `calendar_cols` is passed to the model as `futr_exog_list`.\n",
    "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
    "\n",
    "# Split on the 12th-to-last `ds` value: earlier timestamps train, the rest test.\n",
    "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "model = FEDformer(h=12,\n",
    "                 input_size=24,\n",
    "                 modes=64,\n",
    "                 hidden_size=64,\n",
    "                 conv_hidden_size=128,\n",
    "                 n_head=8,\n",
    "                 loss=MAE(),\n",
    "                 futr_exog_list=calendar_cols,\n",
    "                 scaler_type='robust',\n",
    "                 learning_rate=1e-3,\n",
    "                 max_steps=500,\n",
    "                 batch_size=2,\n",
    "                 windows_batch_size=32,\n",
    "                 val_check_steps=50,\n",
    "                 early_stop_patience_steps=2)\n",
    "\n",
    "nf = NeuralForecast(\n",
    "    models=[model],\n",
    "    freq='M',\n",
    ")\n",
    "nf.fit(df=Y_train_df, static_df=None, val_size=12)\n",
    "forecasts = nf.predict(futr_df=Y_test_df)\n",
    "\n",
    "# Attach forecasts column-wise to the test rows, then prepend the training history.\n",
    "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
    "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
    "plot_df = pd.concat([Y_train_df, plot_df])\n",
    "\n",
    "# Plot a single series; the filter and ground-truth line are shared by both loss types.\n",
    "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
    "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
    "if model.loss.is_distribution_output:\n",
    "    plt.plot(plot_df['ds'], plot_df['FEDformer-median'], c='blue', label='median')\n",
    "    # Shade the 90% interval over the 12 forecast steps only (positional slice).\n",
    "    plt.fill_between(x=plot_df['ds'].iloc[-12:],\n",
    "                     y1=plot_df['FEDformer-lo-90'].iloc[-12:].values,\n",
    "                     y2=plot_df['FEDformer-hi-90'].iloc[-12:].values,\n",
    "                     alpha=0.4, label='level 90')\n",
    "else:\n",
    "    plt.plot(plot_df['ds'], plot_df['FEDformer'], c='blue', label='Forecast')\n",
    "plt.xlabel('Timestamp [t]')\n",
    "plt.ylabel('AirPassengers')\n",
    "plt.legend()\n",
    "plt.grid()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
